d7fc61
@@ -35,6 +35,7 @@
import org.infinispan.context.Flag;
 import org.infinispan.notifications.cachelistener.annotation.CacheEntryCreated;
 import org.infinispan.notifications.cachelistener.annotation.CacheEntryModified;
 import org.infinispan.notifications.cachelistener.event.CacheEntryEvent;
+import org.infinispan.util.concurrent.ConcurrentHashSet;
 import org.jboss.msc.service.ServiceName;
 import org.wildfly.clustering.dispatcher.CommandDispatcher;
 import org.wildfly.clustering.ee.Batch;
@@ -97,22 +98,15 @@
public class CacheServiceProviderRegistry<T> implements ServiceProviderRegistry<
         if (this.listeners.putIfAbsent(service, listener) != null) {
             throw new IllegalArgumentException(service.toString());
         }
-        final Node node = this.group.getLocalNode();
-        Set<Node> nodes = new HashSet<>(Collections.singleton(node));
         try (Batch batch = this.batcher.createBatch()) {
-            Set<Node> existing = this.cache.getAdvancedCache().withFlags(Flag.FORCE_SYNCHRONOUS).putIfAbsent(service, nodes);
-            if (existing != null) {
-                if (existing.add(node)) {
-                    this.cache.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES).replace(service, existing);
-                }
-            }
+            this.register(this.group.getLocalNode(), service);
         }
         return new AbstractServiceProviderRegistration<T>(service, this) {
             @Override
             public void close() {
                 Node node = CacheServiceProviderRegistry.this.getGroup().getLocalNode();
                 try (Batch batch = CacheServiceProviderRegistry.this.batcher.createBatch()) {
-                    Set<Node> nodes = CacheServiceProviderRegistry.this.cache.get(service);
+                    Set<Node> nodes = CacheServiceProviderRegistry.this.cache.getAdvancedCache().withFlags(Flag.FORCE_WRITE_LOCK).get(service);
                     if ((nodes != null) && nodes.remove(node)) {
                         Cache<T, Set<Node>> cache = CacheServiceProviderRegistry.this.cache.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES);
                         if (nodes.isEmpty()) {
@@ -121,12 +115,24 @@
public class CacheServiceProviderRegistry<T> implements ServiceProviderRegistry<
                             cache.replace(service, nodes);
                         }
                     }
+                } finally {
+                    CacheServiceProviderRegistry.this.listeners.remove(service);
                 }
-                CacheServiceProviderRegistry.this.listeners.remove(service);
             }
         };
     }
 
+    // Adds the node to the service's provider set, creating the cache entry when absent.
+    void register(Node node, T service) {
+        Set<Node> initial = new ConcurrentHashSet<>();
+        initial.add(node);
+        Set<Node> current = this.cache.getAdvancedCache().withFlags(Flag.FORCE_SYNCHRONOUS).putIfAbsent(service, initial);
+        // An entry already existed: write it back only if this node was not yet listed.
+        if ((current != null) && current.add(node)) {
+            this.cache.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES).replace(service, current);
+        }
+    }
+
     @Override
     public Set<Node> getProviders(final T service) {
         Set<Node> nodes = this.cache.get(service);
@@ -163,15 +169,8 @@
public class CacheServiceProviderRegistry<T> implements ServiceProviderRegistry<
                     try (Batch batch = this.batcher.createBatch()) {
                         for (Node node: newNodes) {
                             // Re-assert services for new members following merge since these may have been lost following split
-                            List<T> services = CacheServiceProviderRegistry.this.getServices(node);
-                            for (T service: services) {
-                                Set<Node> nodes = new HashSet<>(Collections.singleton(node));
-                                Set<Node> existing = this.cache.getAdvancedCache().withFlags(Flag.FORCE_SYNCHRONOUS).putIfAbsent(service, nodes);
-                                if (existing != null) {
-                                    if (existing.add(node)) {
-                                        this.cache.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES).replace(service, existing);
-                                    }
-                                }
+                            for (T service: CacheServiceProviderRegistry.this.getServices(node)) {
+                                this.register(node, service);
                             }
                         }
                     }
